#!/usr/bin/env python
# coding: utf-8

# # Prerequisites

# In[13]:


# Progress message emitted before the imports below; flush=True pushes it
# out immediately instead of leaving it in the stdout buffer.
print("Loading dependencies", flush=True)
import numpy as np
import matplotlib.pyplot as plt
import mma
from classif_helper import *
import gudhi as gd
from sklearn.neighbors import KernelDensity
import pickle

from sys import argv
# Command line: <script> KMIN KMAX NSAMPLES -- three integers, read positionally.
kmin, kmax, nsamples = int(argv[1]), int(argv[2]), int(argv[3])
# # Dataset generation

# In[2]:


n_pts = 5_000
np.random.seed(0)

# Every group below receives 2/10 of n_pts points.
# NOTE(review): four groups of 2/10 add up to 8/10 of n_pts (4000 rows),
# not n_pts -- confirm this is intended.
group_size = int(n_pts * 2 / 10)

# (first positional arg, second positional arg, center) of each
# mma.noisy_annulus call -- presumably the two radii; verify against mma.
annulus_specs = [
    (0.4, 0.45, [1.2, 1.3]),
    (0.3, 0.31, [0.2, -1]),
    (0.2, 0.201, [-1, 0.5]),
]

# Groups are generated in a fixed order (annuli first, uniform points last).
point_groups = [
    np.array(mma.noisy_annulus(first, second, n1=group_size, n2=0, center=center))
    for first, second, center in annulus_specs
]
point_groups.append(np.random.uniform(low=-2, high=2, size=(group_size, 2)))

# One group per block row, i.e. the groups are stacked vertically.
dataset = np.block([[group] for group in point_groups])
# In-place shuffle of the rows.
np.random.shuffle(dataset)


# # Filtrations

# In[5]:


# Settings shared by the whole script. The complete dict is forwarded as
# keyword arguments to get_bf, persistence_approximation and compute_mods,
# so each consumer has to tolerate the keys it does not use.
params = dict(
    n_jobs=int(cpu_count()),
    # Sampling schedule, taken from the command line.
    kmin=kmin,
    kmax=kmax,
    nsamples=nsamples,
    precision=0.01,
    degree=1,
    resolution=[50, 50],
    kde_bandwidth=0.05,
    box=[[-0.1, -1], [1, 2]],
    kde_kernel="gaussian",
    normalize=1,
    bandwidth=0.1,
    ps=[0, 0.5, 1, 2, np.inf],
    threshold=10,
    flatten=False,
)


# In[6]:


def get_bf(k, **params):
    """Build a 2-parameter simplex tree (Rips + density) on a slice of ``dataset``.

    Parameters
    ----------
    k : int
        Upper bound of the slice ``dataset[1:k]`` used as the point cloud.
    **params
        Optional settings. Only ``kde_kernel`` (default ``"gaussian"``) and
        ``kde_bandwidth`` (default ``0.05``) are read here; any other key is
        accepted and ignored so that a full settings dict can be forwarded.

    Returns
    -------
    mma.SimplexTreeMulti
        The tree after ``fill_lowerstar`` on parameter 1,
        ``collapse_edges(num=100)`` and ``expansion(2)``.
    """
    # NOTE(review): the slice starts at index 1, so row 0 is skipped and only
    # k - 1 points are used -- confirm whether dataset[:k] was intended.
    X = dataset[1:k]
    rips_tree = gd.RipsComplex(points=X, max_edge_length=1).create_simplex_tree()

    # Kernel and bandwidth come from the caller's settings; the defaults are
    # the values that used to be hard-coded here.
    kde = KernelDensity(
        kernel=params.get("kde_kernel", "gaussian"),
        bandwidth=params.get("kde_bandwidth", 0.05),
    ).fit(X)
    # Negated KDE score of every point (scikit-learn documents score_samples
    # as a log-density), used as the second filtration value.
    filtration_density = -np.array(kde.score_samples(X))

    simplextree = mma.SimplexTreeMulti(rips_tree, num_parameters=2)
    # Parameter index 1 receives the density values.
    simplextree.fill_lowerstar(filtration_density, parameter=1)
    simplextree.collapse_edges(num=100)
    simplextree.expansion(2)
    return simplextree
def mod_dump(k:int):
    """Return the dumped persistence approximation for the slice selected by ``k``.

    Builds the simplex tree with ``get_bf`` and calls
    ``persistence_approximation`` on it, both with the module-level ``params``
    forwarded as keyword arguments, then returns the result of ``.dump()``.
    """
    tree = get_bf(k, **params)
    approximation = tree.persistence_approximation(**params)
    return approximation.dump()


# # Computation

# In[16]:


start = params["kmin"]
stop = params["kmax"]
num = params["nsamples"]


# In[17]:


import os

# open(..., "wb") does not create missing parent directories, so make sure
# the output folder exists before writing anything into it.
os.makedirs("modules/synthetic2", exist_ok=True)

# `num` integer values spread evenly between `start` and `stop`.
iterator = np.linspace(start=start, stop=stop, num=num, dtype=int)
# Saved through an open file object, so the name keeps its ".np" suffix as is.
with open(f"modules/synthetic2/iterator_{start}_{stop}_{num}.np", "wb") as f:
    np.save(f, iterator)

# In[21]:
print("Computing modules...", flush=True)

# Path prefix handed to compute_mods through `save`; presumably one file per
# entry of `iterator` is written under it -- confirm in classif_helper.
module_prefix = f"modules/synthetic2/module_{start}_{stop}_{num}_"
compute_mods(iterator, get_bf, dump=True, save=module_prefix, **params)

print("Done !")




